package com.carol.bigdata.task.model.algo

import java.io.File

import org.apache.spark.ml.classification.{Classifier, OneVsRest}
import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, MulticlassClassificationEvaluator}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tuning.{CrossValidator, CrossValidatorModel, ParamGridBuilder}
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.sql.DataFrame


/**
  * Common contract and shared helpers for Spark ML classification model wrappers.
  *
  * Members without a body must be implemented by concrete models; members with a
  * body can be used by implementations as-is.
  */
trait ModelTrait {

    /**
      * Implementation-defined initialisation hook.
      *
      * @param params loosely-typed settings; the accepted keys are defined by each implementation
      */
    def init(params: Map[String, Any]): Unit

    /**
      * Builds the (unfitted) pipeline for this model.
      *
      * @param featuresCol      name of the feature column
      * @param labelCol         name of the label column
      * @param rawPredictionCol name of the raw prediction column
      * @param predictionCol    name of the prediction column
      * @param objective        learning objective; the helpers in this trait treat "binary"
      *                         as binary classification and anything else as multiclass
      * @param numClass         number of classes
      * @return the pipeline to be fitted
      */
    def buildPipeline(featuresCol: String = "features",
                      labelCol: String = "label",
                      rawPredictionCol: String = "rawPrediction",
                      predictionCol: String = "prediction",
                      objective: String = "binary",
                      numClass: Int = 2): Pipeline


    /**
      * Builds a cross validator around `pipeline`.
      *
      * Implementations supply their own parameter grid; see [[buildValidatorFromGrid]]
      * for a ready-made way to assemble the validator once a grid is available.
      *
      * @param pipeline    estimator to tune
      * @param seed        random seed
      * @param numFolds    number of folds
      * @param parallelNum parallelism level
      * @param objective   "binary" or any other value for multiclass
      * @return the configured cross validator
      */
    def buildValidator(pipeline: Pipeline,
                       seed: Int = 1,
                       numFolds: Int = 2,
                       parallelNum: Int = 2,
                       objective: String = "binary"): CrossValidator


    /**
      * Wraps a binary classifier in a `OneVsRest` reduction so it can be used for
      * multiclass problems.
      *
      * NOTE(review): `rawPredictionCol` and `objective` are accepted for signature parity
      * with [[buildPipeline]] but are not applied to the returned estimator.
      *
      * @param binaryModel      base binary classifier
      * @param featuresCol      name of the feature column
      * @param labelCol         name of the label column
      * @param rawPredictionCol currently unused
      * @param predictionCol    name of the prediction column
      * @param objective        currently unused
      * @return the configured `OneVsRest` estimator
      */
    def buildOvrTree(binaryModel: Classifier[_, _, _],
                     featuresCol: String = "features",
                     labelCol: String = "label",
                     rawPredictionCol: String = "rawPrediction",
                     predictionCol: String = "prediction",
                     objective: String = "binary"): OneVsRest =
        new OneVsRest()
            .setClassifier(binaryModel)
            .setFeaturesCol(featuresCol)
            .setLabelCol(labelCol)
            .setPredictionCol(predictionCol)


    /**
      * Builds a cross validator from a pipeline and a parameter grid builder.
      *
      * The evaluator is created with its default settings (default column names and
      * default metric): a `BinaryClassificationEvaluator` when `objective` is "binary",
      * otherwise a `MulticlassClassificationEvaluator`.
      *
      * @param pipeline    estimator to tune
      * @param gridBuilder builder holding the hyper-parameter grid
      * @param seed        random seed used for the fold split
      * @param numFolds    number of folds (3 or more is advisable in practice)
      * @param parallelNum maximum number of parameter settings evaluated in parallel
      * @param objective   "binary" or any other value for multiclass
      * @return the configured cross validator
      */
    def buildValidatorFromGrid(pipeline: Pipeline,
                               gridBuilder: ParamGridBuilder,
                               seed: Int = 1,
                               numFolds: Int = 2,
                               parallelNum: Int = 2,
                               objective: String = "binary"): CrossValidator = {
        // Materialise the hyper-parameter combinations to search over.
        val paramGrid: Array[ParamMap] = gridBuilder.build()
        // Pick the evaluator matching the objective.
        val evaluator = {
            if (objective == "binary")
                new BinaryClassificationEvaluator
            else new MulticlassClassificationEvaluator
        }
        new CrossValidator()
            .setEstimator(pipeline)
            .setEstimatorParamMaps(paramGrid)
            .setEvaluator(evaluator)
            .setSeed(seed)
            .setNumFolds(numFolds)
            .setParallelism(parallelNum)
            //.setCollectSubModels(true)    // would keep every sub-model trained during validation
    }


    /**
      * Extracts the best hyper-parameters found by cross validation.
      *
      * "Best" follows the evaluator stored in `crossModel`: the highest average metric when
      * the evaluator reports larger-is-better, the lowest otherwise.
      *
      * Keys are bare parameter names, so if two pipeline stages tune a parameter with the
      * same name only one of the values is kept.
      *
      * @param crossModel fitted cross validator model
      * @return parameter name -> value for the winning parameter map
      */
    def getBestParams(crossModel: CrossValidatorModel): Map[String, Any] = {
        val scoredParamMaps = crossModel.getEstimatorParamMaps.zip(crossModel.avgMetrics)
        // Respect the metric direction instead of always taking the maximum.
        val (bestParamsMap, _) =
            if (crossModel.getEvaluator.isLargerBetter) scoredParamMaps.maxBy(_._2)
            else scoredParamMaps.minBy(_._2)
        bestParamsMap.toSeq.map(pair => (pair.param.name, pair.value)).toMap
    }


    /**
      * Saves a fitted pipeline model, overwriting anything already at `modelPath`.
      *
      * @param pipeline  fitted pipeline model to persist
      * @param modelPath destination path (local filesystem or HDFS); it can later be
      *                  loaded back as a `PipelineModel`
      */
    def saveModel(pipeline: PipelineModel, modelPath: String = "model"): Unit = {
        println("pipeline model saving...")
        pipeline.write.overwrite().save(modelPath)
        println(s"pipeline model save success to $modelPath")
    }


    /**
      * Evaluates a scored data set and prints the resulting metric.
      *
      * The evaluators are used with their defaults, so the printed metric is the
      * evaluator's default metric (not necessarily accuracy) and `evalData` is expected to
      * already contain the default label / prediction columns those evaluators read.
      *
      * @param evalData  data frame holding labels and model output
      * @param objective "binary" or any other value for multiclass
      */
    def evalModel(evalData: DataFrame, objective: String = "binary"): Unit = {
        val (metricName, metricValue) =
            if (objective == "binary") {
                val evaluator = new BinaryClassificationEvaluator
                (evaluator.getMetricName, evaluator.evaluate(evalData))
            } else {
                val evaluator = new MulticlassClassificationEvaluator
                (evaluator.getMetricName, evaluator.evaluate(evalData))
            }
        // Interpolate instead of passing two arguments, which would print a tuple.
        println(s"$metricName: $metricValue")
    }


    /**
      * Updates the tunable parameters from a parameter map.
      *
      * @param bestParamsMap parameter map to apply
      */
    def updateTuneParams(bestParamsMap: ParamMap): Unit

    /**
      * Updates the tunable parameters from a fitted cross validator model.
      *
      * @param crossModel fitted cross validator model
      * @param maxBy      presumably whether the best entry is the one with the highest
      *                   metric -- confirm against implementations
      * @param objective  "binary" or any other value for multiclass
      */
    def updateTuneParamsFromCV(crossModel: CrossValidatorModel,
                               maxBy: Boolean = true,
                               objective: String = "binary"): Unit


//    NOTE(review): disabled code kept for reference. If it is re-enabled, do not keep a
//    default endpoint and default credentials in source.
//    // Upload the model
//    def uploadModel(bucket: String,
//                    objectPath: String,
//                    modelPath: String,
//                    suffix: String = ".txt",
//                    endPoint: String = "http://172.13.1.230:9080",
//                    accessKey: String = "admin",
//                    secretKey: String = "12345678"): Unit = {
//        // Upload to the minio server
//        val minioUtil = new MinioUtil(endPoint, accessKey, secretKey)
//        val file = new File(modelPath)
//        var filePath = ""
//        // If the path is a file, upload it; if it is a directory, pick a file ending with suffix
//        if (file.isDirectory) {
//            val fileList = file.listFiles().filter(t => t.toString.endsWith(suffix))
//            filePath = fileList.head.getPath
//        } else filePath = file.getPath
//        minioUtil.upload(bucket, objectPath, filePath)
//        println(s"model upload success to ${filePath}")
//    }


}

